{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b01c06a",
   "metadata": {},
   "source": [
    "# NAACL'21 DLG4NLP Tutorial Demo: Text Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "132dd151",
   "metadata": {},
   "source": [
    "In this tutorial demo, we will use the Graph4NLP library to build a GNN-based text classification model. The model consists of \n",
    "- graph construction module (e.g., dependency based static graph)\n",
    "- graph embedding module (e.g., Bi-Fuse GraphSAGE)\n",
    "- predictoin module (e.g., graph pooling + MLP classifier)\n",
    "\n",
    "We will use the built-in module APIs to build the model, and evaluate it on the TREC dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9b3c0d5",
   "metadata": {},
   "source": [
    "### Environment setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b946a75",
   "metadata": {},
   "source": [
    "1. Create virtual environment\n",
    "```\n",
    "conda create --name graph4nlp python=3.7\n",
    "conda activate graph4nlp\n",
    "```\n",
    "\n",
    "2. Install [graph4nlp](https://github.com/graph4ai/graph4nlp) library\n",
    "- Clone the github repo\n",
    "```\n",
    "git clone -b stable https://github.com/graph4ai/graph4nlp.git\n",
    "cd graph4nlp\n",
    "```\n",
    "- Then run `./configure` (or `./configure.bat` if you are using Windows 10) to config your installation. The configuration program will ask you to specify your CUDA version. If you do not have a GPU, please choose 'cpu'.\n",
    "```\n",
    "./configure\n",
    "```\n",
    "- Finally, install the package\n",
    "```\n",
    "python setup.py install\n",
    "```\n",
    "\n",
    "3. Set up StanfordCoreNLP (for static graph construction only, unnecessary for this demo because preprocessed data is provided)\n",
    "- Download [StanfordCoreNLP](https://stanfordnlp.github.io/CoreNLP/)\n",
    "- Go to the root folder and start the server\n",
    "```\n",
    "java -mx4g -cp \"*\" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd5f5762",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import datetime\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "from graph4nlp.pytorch.datasets.trec import TrecDataset\n",
    "from graph4nlp.pytorch.modules.graph_construction import *\n",
    "from graph4nlp.pytorch.modules.graph_construction.embedding_construction import WordEmbedding\n",
    "from graph4nlp.pytorch.modules.graph_embedding import *\n",
    "from graph4nlp.pytorch.modules.prediction.classification.graph_classification import FeedForwardNN\n",
    "from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase\n",
    "from graph4nlp.pytorch.modules.evaluation.accuracy import Accuracy\n",
    "from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping\n",
    "from graph4nlp.pytorch.modules.loss.general_loss import GeneralLoss\n",
    "from graph4nlp.pytorch.modules.utils.logger import Logger\n",
    "from graph4nlp.pytorch.modules.utils import constants as Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "669b55bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextClassifier(nn.Module):\n",
    "    def __init__(self, vocab, config):\n",
    "        super(TextClassifier, self).__init__()\n",
    "        self.config = config\n",
    "        self.vocab = vocab\n",
    "        embedding_style = {'single_token_item': True if config['graph_type'] != 'ie' else False,\n",
    "                            'emb_strategy': config.get('emb_strategy', 'w2v_bilstm'),\n",
    "                            'num_rnn_layers': 1,\n",
    "                            'bert_model_name': config.get('bert_model_name', 'bert-base-uncased'),\n",
    "                            'bert_lower_case': True\n",
    "                           }\n",
    "\n",
    "        assert not (config['graph_type'] in ('node_emb', 'node_emb_refined') and config['gnn'] == 'gat'), \\\n",
    "                                'dynamic graph construction does not support GAT'\n",
    "\n",
    "        use_edge_weight = False\n",
    "        if config['graph_type'] == 'dependency':\n",
    "            self.graph_topology = DependencyBasedGraphConstruction(\n",
    "                                   embedding_style=embedding_style,\n",
    "                                   vocab=vocab.in_word_vocab,\n",
    "                                   hidden_size=config['num_hidden'],\n",
    "                                   word_dropout=config['word_dropout'],\n",
    "                                   rnn_dropout=config['rnn_dropout'],\n",
    "                                   fix_word_emb=not config['no_fix_word_emb'],\n",
    "                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))\n",
    "        elif config['graph_type'] == 'constituency':\n",
    "            self.graph_topology = ConstituencyBasedGraphConstruction(\n",
    "                                   embedding_style=embedding_style,\n",
    "                                   vocab=vocab.in_word_vocab,\n",
    "                                   hidden_size=config['num_hidden'],\n",
    "                                   word_dropout=config['word_dropout'],\n",
    "                                   rnn_dropout=config['rnn_dropout'],\n",
    "                                   fix_word_emb=not config['no_fix_word_emb'],\n",
    "                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))\n",
    "        elif config['graph_type'] == 'ie':\n",
    "            self.graph_topology = IEBasedGraphConstruction(\n",
    "                                   embedding_style=embedding_style,\n",
    "                                   vocab=vocab.in_word_vocab,\n",
    "                                   hidden_size=config['num_hidden'],\n",
    "                                   word_dropout=config['word_dropout'],\n",
    "                                   rnn_dropout=config['rnn_dropout'],\n",
    "                                   fix_word_emb=not config['no_fix_word_emb'],\n",
    "                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))\n",
    "        elif config['graph_type'] == 'node_emb':\n",
    "            self.graph_topology = NodeEmbeddingBasedGraphConstruction(\n",
    "                                   vocab.in_word_vocab,\n",
    "                                   embedding_style,\n",
    "                                   sim_metric_type=config['gl_metric_type'],\n",
    "                                   num_heads=config['gl_num_heads'],\n",
    "                                   top_k_neigh=config['gl_top_k'],\n",
    "                                   epsilon_neigh=config['gl_epsilon'],\n",
    "                                   smoothness_ratio=config['gl_smoothness_ratio'],\n",
    "                                   connectivity_ratio=config['gl_connectivity_ratio'],\n",
    "                                   sparsity_ratio=config['gl_sparsity_ratio'],\n",
    "                                   input_size=config['num_hidden'],\n",
    "                                   hidden_size=config['gl_num_hidden'],\n",
    "                                   fix_word_emb=not config['no_fix_word_emb'],\n",
    "                                   fix_bert_emb=not config.get('no_fix_bert_emb', False),\n",
    "                                   word_dropout=config['word_dropout'],\n",
    "                                   rnn_dropout=config['rnn_dropout'])\n",
    "            use_edge_weight = True\n",
    "        elif config['graph_type'] == 'node_emb_refined':\n",
    "            self.graph_topology = NodeEmbeddingBasedRefinedGraphConstruction(\n",
    "                                    vocab.in_word_vocab,\n",
    "                                    embedding_style,\n",
    "                                    config['init_adj_alpha'],\n",
    "                                    sim_metric_type=config['gl_metric_type'],\n",
    "                                    num_heads=config['gl_num_heads'],\n",
    "                                    top_k_neigh=config['gl_top_k'],\n",
    "                                    epsilon_neigh=config['gl_epsilon'],\n",
    "                                    smoothness_ratio=config['gl_smoothness_ratio'],\n",
    "                                    connectivity_ratio=config['gl_connectivity_ratio'],\n",
    "                                    sparsity_ratio=config['gl_sparsity_ratio'],\n",
    "                                    input_size=config['num_hidden'],\n",
    "                                    hidden_size=config['gl_num_hidden'],\n",
    "                                    fix_word_emb=not config['no_fix_word_emb'],\n",
    "                                    fix_bert_emb=not config.get('no_fix_bert_emb', False),\n",
    "                                    word_dropout=config['word_dropout'],\n",
    "                                    rnn_dropout=config['rnn_dropout'])\n",
    "            use_edge_weight = True\n",
    "        else:\n",
    "            raise RuntimeError('Unknown graph_type: {}'.format(config['graph_type']))\n",
    "\n",
    "        if 'w2v' in self.graph_topology.embedding_layer.word_emb_layers:\n",
    "            self.word_emb = self.graph_topology.embedding_layer.word_emb_layers['w2v'].word_emb_layer\n",
    "        else:\n",
    "            self.word_emb = WordEmbedding(\n",
    "                            self.vocab.in_word_vocab.embeddings.shape[0],\n",
    "                            self.vocab.in_word_vocab.embeddings.shape[1],\n",
    "                            pretrained_word_emb=self.vocab.in_word_vocab.embeddings,\n",
    "                            fix_emb=not config['no_fix_word_emb'],\n",
    "                            device=config['device']).word_emb_layer\n",
    "\n",
    "        if config['gnn'] == 'gat':\n",
    "            heads = [config['gat_num_heads']] * (config['gnn_num_layers'] - 1) + [config['gat_num_out_heads']]\n",
    "            self.gnn = GAT(config['gnn_num_layers'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        heads,\n",
    "                        direction_option=config['gnn_direction_option'],\n",
    "                        feat_drop=config['gnn_dropout'],\n",
    "                        attn_drop=config['gat_attn_dropout'],\n",
    "                        negative_slope=config['gat_negative_slope'],\n",
    "                        residual=config['gat_residual'],\n",
    "                        activation=F.elu)\n",
    "        elif config['gnn'] == 'graphsage':\n",
    "            self.gnn = GraphSAGE(config['gnn_num_layers'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        config['graphsage_aggreagte_type'],\n",
    "                        direction_option=config['gnn_direction_option'],\n",
    "                        feat_drop=config['gnn_dropout'],\n",
    "                        bias=True,\n",
    "                        norm=None,\n",
    "                        activation=F.relu,\n",
    "                        use_edge_weight=use_edge_weight)\n",
    "        elif config['gnn'] == 'ggnn':\n",
    "            self.gnn = GGNN(config['gnn_num_layers'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        config['num_hidden'],\n",
    "                        feat_drop=config['gnn_dropout'],\n",
    "                        direction_option=config['gnn_direction_option'],\n",
    "                        bias=True,\n",
    "                        use_edge_weight=use_edge_weight)\n",
    "        else:\n",
    "            raise RuntimeError('Unknown gnn type: {}'.format(config['gnn']))\n",
    "\n",
    "        self.clf = FeedForwardNN(2 * config['num_hidden'] \\\n",
    "                        if config['gnn_direction_option'] == 'bi_sep' \\\n",
    "                        else config['num_hidden'],\n",
    "                        config['num_classes'],\n",
    "                        [config['num_hidden']],\n",
    "                        graph_pool_type=config['graph_pooling'],\n",
    "                        dim=config['num_hidden'],\n",
    "                        use_linear_proj=config['max_pool_linear_proj'])\n",
    "\n",
    "        self.loss = GeneralLoss('CrossEntropy')\n",
    "\n",
    "\n",
    "    def forward(self, graph_list, tgt=None, require_loss=True):\n",
    "        # build graph topology\n",
    "        batch_gd = self.graph_topology(graph_list)\n",
    "\n",
    "        # run GNN encoder\n",
    "        self.gnn(batch_gd)\n",
    "\n",
    "        # run graph classifier\n",
    "        self.clf(batch_gd)\n",
    "        logits = batch_gd.graph_attributes['logits']\n",
    "\n",
    "        if require_loss:\n",
    "            loss = self.loss(logits, tgt)\n",
    "            return logits, loss\n",
    "        else:\n",
    "            return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eda1f453",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelHandler:\n",
    "    def __init__(self, config):\n",
    "        super(ModelHandler, self).__init__()\n",
    "        self.config = config\n",
    "        self.logger = Logger(self.config['out_dir'], config={k:v for k, v in self.config.items() if k != 'device'}, overwrite=True)\n",
    "        self.logger.write(self.config['out_dir'])\n",
    "        self._build_device()\n",
    "        self._build_dataloader()\n",
    "        self._build_model()\n",
    "        self._build_optimizer()\n",
    "        self._build_evaluation()\n",
    "\n",
    "    def _build_device(self):\n",
    "        if not self.config['no_cuda'] and torch.cuda.is_available():\n",
    "            print('[ Using CUDA ]')\n",
    "            self.config['device'] = torch.device('cuda' if self.config['gpu'] < 0 else 'cuda:%d' % self.config['gpu'])\n",
    "            torch.cuda.manual_seed(self.config['seed'])\n",
    "            torch.cuda.manual_seed_all(self.config['seed'])\n",
    "            torch.backends.cudnn.deterministic = True\n",
    "            cudnn.benchmark = False\n",
    "        else:\n",
    "            self.config['device'] = torch.device('cpu')\n",
    "        \n",
    "    def _build_dataloader(self):\n",
    "        dynamic_init_topology_builder = None\n",
    "        if self.config['graph_type'] == 'dependency':\n",
    "            topology_builder = DependencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            merge_strategy = 'tailhead'\n",
    "        elif self.config['graph_type'] == 'constituency':\n",
    "            topology_builder = ConstituencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            merge_strategy = 'tailhead'\n",
    "        elif self.config['graph_type'] == 'ie':\n",
    "            topology_builder = IEBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            merge_strategy = 'global'\n",
    "        elif self.config['graph_type'] == 'node_emb':\n",
    "            topology_builder = NodeEmbeddingBasedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            merge_strategy = None\n",
    "        elif self.config['graph_type'] == 'node_emb_refined':\n",
    "            topology_builder = NodeEmbeddingBasedRefinedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            merge_strategy = 'tailhead'\n",
    "\n",
    "            if self.config['init_graph_type'] == 'line':\n",
    "                dynamic_init_topology_builder = None\n",
    "            elif self.config['init_graph_type'] == 'dependency':\n",
    "                dynamic_init_topology_builder = DependencyBasedGraphConstruction\n",
    "            elif self.config['init_graph_type'] == 'constituency':\n",
    "                dynamic_init_topology_builder = ConstituencyBasedGraphConstruction\n",
    "            elif self.config['init_graph_type'] == 'ie':\n",
    "                merge_strategy = 'global'\n",
    "                dynamic_init_topology_builder = IEBasedGraphConstruction\n",
    "            else:\n",
    "                raise RuntimeError('Define your own dynamic_init_topology_builder')\n",
    "        else:\n",
    "            raise RuntimeError('Unknown graph_type: {}'.format(self.config['graph_type']))\n",
    "\n",
    "        topology_subdir = '{}_graph'.format(self.config['graph_type'])\n",
    "        if self.config['graph_type'] == 'node_emb_refined':\n",
    "            topology_subdir += '_{}'.format(self.config['init_graph_type'])\n",
    "\n",
    "        dataset = TrecDataset(root_dir=self.config.get('root_dir', self.config['root_data_dir']),\n",
    "                              pretrained_word_emb_name=self.config.get('pretrained_word_emb_name', \"840B\"),\n",
    "                              merge_strategy=merge_strategy,\n",
    "                              seed=self.config['seed'],\n",
    "                              thread_number=4,\n",
    "                              port=9000,\n",
    "                              timeout=15000,\n",
    "                              word_emb_size=300,\n",
    "                              graph_type=graph_type,\n",
    "                              topology_builder=topology_builder,\n",
    "                              topology_subdir=topology_subdir,\n",
    "                              dynamic_graph_type=self.config['graph_type'] if \\\n",
    "                                  self.config['graph_type'] in ('node_emb', 'node_emb_refined') else None,\n",
    "                              dynamic_init_topology_builder=dynamic_init_topology_builder,\n",
    "                              dynamic_init_topology_aux_args={'dummy_param': 0})\n",
    "\n",
    "        self.train_dataloader = DataLoader(dataset.train, batch_size=self.config['batch_size'], shuffle=True,\n",
    "                                           num_workers=self.config['num_workers'],\n",
    "                                           collate_fn=dataset.collate_fn)\n",
    "        if hasattr(dataset, 'val')==False:\n",
    "            dataset.val = dataset.test\n",
    "        self.val_dataloader = DataLoader(dataset.val, batch_size=self.config['batch_size'], shuffle=False,\n",
    "                                          num_workers=self.config['num_workers'],\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.test_dataloader = DataLoader(dataset.test, batch_size=self.config['batch_size'], shuffle=False,\n",
    "                                          num_workers=self.config['num_workers'],\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.vocab = dataset.vocab_model\n",
    "        self.config['num_classes'] = dataset.num_classes\n",
    "        self.num_train = len(dataset.train)\n",
    "        self.num_val = len(dataset.val)\n",
    "        self.num_test = len(dataset.test)\n",
    "        print('Train size: {}, Val size: {}, Test size: {}'\n",
    "            .format(self.num_train, self.num_val, self.num_test))\n",
    "        self.logger.write('Train size: {}, Val size: {}, Test size: {}'\n",
    "            .format(self.num_train, self.num_val, self.num_test))\n",
    "\n",
    "    def _build_model(self):\n",
    "        self.model = TextClassifier(self.vocab, self.config).to(self.config['device'])\n",
    "\n",
    "    def _build_optimizer(self):\n",
    "        parameters = [p for p in self.model.parameters() if p.requires_grad]\n",
    "        self.optimizer = optim.Adam(parameters, lr=self.config['lr'])\n",
    "        self.stopper = EarlyStopping(os.path.join(self.config['out_dir'], Constants._SAVED_WEIGHTS_FILE), patience=self.config['patience'])\n",
    "        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_reduce_factor'], \\\n",
    "            patience=self.config['lr_patience'], verbose=True)\n",
    "\n",
    "    def _build_evaluation(self):\n",
    "        self.metric = Accuracy(['accuracy'])\n",
    "\n",
    "    def train(self):\n",
    "        dur = []\n",
    "        for epoch in range(self.config['epochs']):\n",
    "            self.model.train()\n",
    "            train_loss = []\n",
    "            train_acc = []\n",
    "            t0 = time.time()\n",
    "            for i, data in enumerate(self.train_dataloader):\n",
    "                tgt = data['tgt_tensor'].to(self.config['device'])\n",
    "                data['graph_data'] = data['graph_data'].to(self.config['device'])\n",
    "                logits, loss = self.model(data['graph_data'], tgt, require_loss=True)\n",
    "\n",
    "                # add graph regularization loss if available\n",
    "                if data['graph_data'].graph_attributes.get('graph_reg', None) is not None:\n",
    "                    loss = loss + data['graph_data'].graph_attributes['graph_reg']\n",
    "\n",
    "                self.optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                self.optimizer.step()\n",
    "                train_loss.append(loss.item())\n",
    "\n",
    "                pred = torch.max(logits, dim=-1)[1].cpu()\n",
    "                train_acc.append(self.metric.calculate_scores(ground_truth=tgt.cpu(), predict=pred.cpu(), zero_division=0)[0])\n",
    "                dur.append(time.time() - t0)\n",
    "\n",
    "            val_acc = self.evaluate(self.val_dataloader)\n",
    "            self.scheduler.step(val_acc)\n",
    "            print('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.\n",
    "              format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))\n",
    "            self.logger.write('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.\n",
    "                        format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))\n",
    "\n",
    "            if self.stopper.step(val_acc, self.model):\n",
    "                break\n",
    "\n",
    "        return self.stopper.best_score\n",
    "\n",
    "    def evaluate(self, dataloader):\n",
    "        self.model.eval()\n",
    "        with torch.no_grad():\n",
    "            pred_collect = []\n",
    "            gt_collect = []\n",
    "            for i, data in enumerate(dataloader):\n",
    "                tgt = data['tgt_tensor'].to(self.config['device'])\n",
    "                data['graph_data'] = data['graph_data'].to(self.config[\"device\"])\n",
    "                logits = self.model(data['graph_data'], require_loss=False)\n",
    "                pred_collect.append(logits)\n",
    "                gt_collect.append(tgt)\n",
    "\n",
    "            pred_collect = torch.max(torch.cat(pred_collect, 0), dim=-1)[1].cpu()\n",
    "            gt_collect = torch.cat(gt_collect, 0).cpu()\n",
    "            score = self.metric.calculate_scores(ground_truth=gt_collect, predict=pred_collect, zero_division=0)[0]\n",
    "\n",
    "            return score\n",
    "\n",
    "    def test(self):\n",
    "        # restored best saved model\n",
    "        self.stopper.load_checkpoint(self.model)\n",
    "\n",
    "        t0 = time.time()\n",
    "        acc = self.evaluate(self.test_dataloader)\n",
    "        dur = time.time() - t0\n",
    "        print('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.\n",
    "          format(self.num_test, dur, acc))\n",
    "        self.logger.write('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.\n",
    "          format(self.num_test, dur, acc))\n",
    "\n",
    "        return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b2bff971",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_config(config):\n",
    "    print('**************** MODEL CONFIGURATION ****************')\n",
    "    for key in sorted(config.keys()):\n",
    "        val = config[key]\n",
    "        keystr = '{}'.format(key) + (' ' * (24 - len(key)))\n",
    "        print('{} -->   {}'.format(keystr, val))\n",
    "    print('**************** MODEL CONFIGURATION ****************')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "31b9931a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**************** MODEL CONFIGURATION ****************\n",
      "batch_size               -->   50\n",
      "dataset                  -->   trec\n",
      "epochs                   -->   500\n",
      "gat_attn_dropout         -->   None\n",
      "gat_negative_slope       -->   None\n",
      "gat_num_heads            -->   None\n",
      "gat_num_out_heads        -->   None\n",
      "gat_residual             -->   False\n",
      "gl_connectivity_ratio    -->   None\n",
      "gl_epsilon               -->   None\n",
      "gl_metric_type           -->   None\n",
      "gl_num_heads             -->   1\n",
      "gl_num_hidden            -->   300\n",
      "gl_smoothness_ratio      -->   None\n",
      "gl_sparsity_ratio        -->   None\n",
      "gl_top_k                 -->   None\n",
      "gnn                      -->   graphsage\n",
      "gnn_direction_option     -->   bi_fuse\n",
      "gnn_dropout              -->   0.3\n",
      "gnn_num_layers           -->   1\n",
      "gpu                      -->   0\n",
      "graph_pooling            -->   avg_pool\n",
      "graph_type               -->   dependency\n",
      "graphsage_aggreagte_type -->   lstm\n",
      "init_adj_alpha           -->   None\n",
      "init_graph_type          -->   None\n",
      "lr                       -->   0.001\n",
      "lr_patience              -->   2\n",
      "lr_reduce_factor         -->   0.5\n",
      "max_pool_linear_proj     -->   False\n",
      "no_cuda                  -->   False\n",
      "no_fix_word_emb          -->   False\n",
      "node_edge_emb_strategy   -->   mean\n",
      "num_hidden               -->   300\n",
      "num_workers              -->   1\n",
      "out_dir                  -->   out/trec/graphsage_bi_fuse_dependency_ckpt\n",
      "patience                 -->   10\n",
      "pretrained_word_emb_name -->   840B\n",
      "rnn_dropout              -->   0.1\n",
      "root_data_dir            -->   ../data/trec\n",
      "seed                     -->   1234\n",
      "seq_info_encode_strategy -->   bilstm\n",
      "val_split_ratio          -->   0.2\n",
      "word_dropout             -->   0.4\n",
      "**************** MODEL CONFIGURATION ****************\n"
     ]
    }
   ],
   "source": [
    "# config setup\n",
    "config_file = '../config/trec/graphsage_bi_fuse_static_dependency.yaml'\n",
    "config = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader)\n",
    "print_config(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5087ca0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149\n",
      "Loading pre-built vocab model stored in ../data/trec/processed/dependency_graph/vocab.pt\n",
      "Train size: 5452, Val size: 500, Test size: 500\n",
      "[ Fix word embeddings ]\n",
      "Epoch: [1 / 500] | Time: 15.75s | Loss: 1.1737 | Train Acc: 0.5293 | Val Acc: 0.7460\n",
      "Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149/params.saved\n"
     ]
    }
   ],
   "source": [
    "# run model\n",
    "np.random.seed(config['seed'])\n",
    "torch.manual_seed(config['seed'])\n",
    "\n",
    "ts = datetime.datetime.now().timestamp()\n",
    "config['out_dir'] += '_{}'.format(ts)\n",
    "print('\\n' + config['out_dir'])\n",
    "\n",
    "runner = ModelHandler(config)\n",
    "t0 = time.time()\n",
    "\n",
    "val_acc = runner.train()\n",
    "test_acc = runner.test()\n",
    "\n",
    "runtime = time.time() - t0\n",
    "print('Total runtime: {:.2f}s'.format(runtime))\n",
    "runner.logger.write('Total runtime: {:.2f}s\\n'.format(runtime))\n",
    "runner.logger.close()\n",
    "\n",
    "print('val acc: {}, test acc: {}'.format(val_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c537b575",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
